import os
import torch
import pandas as pd
import argparse
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as T

# --- 1. Configuration ---
# Values follow the "A. Data Preprocessing" section of the paper.

# PRPD images are resized to IMG_RESIZE and cut into square tiles of
# IMG_PATCH_SIZE pixels with IMG_CHANNELS colour channels:
# (256 // 16) * (256 // 16) = 256 patches of 3 * 16 * 16 = 768 values each.
IMG_RESIZE = (256, 256)
IMG_PATCH_SIZE = 16
IMG_CHANNELS = 3

# PD time series must contain exactly SIGNAL_LENGTH samples and are split into
# 153_600 // 60 = 2_560 consecutive windows of SIGNAL_PATCH_SIZE samples.
SIGNAL_LENGTH = 153_600
SIGNAL_PATCH_SIZE = 60

# --- 2. Function Definitions ---

def image_to_patches(image_path):
    """Converts a PRPD image into a sequence of flattened, non-overlapping patches.

    The image is converted to RGB, resized to ``IMG_RESIZE`` and split
    ViT-style into ``IMG_PATCH_SIZE`` x ``IMG_PATCH_SIZE`` tiles. Tiles are
    ordered row by row, and each tile is flattened channel-first
    (C, Patch_H, Patch_W).

    Args:
        image_path: Path to the image file.

    Returns:
        A tensor of shape ``(num_patches, IMG_CHANNELS * IMG_PATCH_SIZE ** 2)``
        -- ``(256, 768)`` with the default configuration -- or ``None`` if the
        file is missing or cannot be decoded (a warning is printed).
    """
    try:
        # Use a context manager so the file handle is released right away;
        # convert() returns a new image that does not depend on the file.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
    except FileNotFoundError:
        print(f"Warning: Image file not found at {image_path}, skipping.")
        return None
    except OSError as err:
        # Corrupt/unsupported files raise PIL's UnidentifiedImageError (an
        # OSError subclass); skip the sample instead of aborting the whole run.
        print(f"Warning: Could not read image {image_path} ({err}), skipping.")
        return None

    transform = T.Compose([
        T.Resize(IMG_RESIZE),
        T.ToTensor(),
    ])
    img_tensor = transform(image)  # Shape: (C, H, W)

    # Unfold H then W into non-overlapping IMG_PATCH_SIZE windows.
    patches = img_tensor.unfold(1, IMG_PATCH_SIZE, IMG_PATCH_SIZE).unfold(2, IMG_PATCH_SIZE, IMG_PATCH_SIZE)
    # Shape: (C, NumPatches_H, NumPatches_W, Patch_H, Patch_W)

    patches = patches.permute(1, 2, 0, 3, 4)
    # Shape: (NumPatches_H, NumPatches_W, C, Patch_H, Patch_W)

    patch_dim = IMG_CHANNELS * IMG_PATCH_SIZE * IMG_PATCH_SIZE
    # reshape copies the permuted (non-contiguous) view as needed; -1 infers
    # NumPatches_H * NumPatches_W.
    return patches.reshape(-1, patch_dim)  # Final Shape: (256, 768)

def signal_to_patches(csv_path):
    """Converts a PD time-series CSV into a sequence of fixed-size patches.

    Every cell of the header-less CSV is read, flattened in row-major order,
    validated to contain exactly ``SIGNAL_LENGTH`` samples and split into
    consecutive, non-overlapping windows of ``SIGNAL_PATCH_SIZE`` samples.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        A ``torch.float32`` tensor of shape
        ``(SIGNAL_LENGTH // SIGNAL_PATCH_SIZE, SIGNAL_PATCH_SIZE)`` --
        ``(2560, 60)`` with the default configuration -- or ``None`` if the
        file is missing, unparseable, non-numeric or of the wrong length
        (a warning is printed in each case).
    """
    csv_name = os.path.basename(csv_path)
    try:
        df = pd.read_csv(csv_path, header=None)
    except FileNotFoundError:
        print(f"Warning: CSV file not found at {csv_path}, skipping.")
        return None
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Empty, malformed or binary files would otherwise abort the whole batch.
        print(f"Warning: Could not parse CSV {csv_name} ({err}). Skipping.")
        return None

    try:
        # Convert explicitly: a non-numeric cell makes the column object
        # dtype, which torch.tensor cannot consume.
        # NOTE(review): row-major flattening assumes the CSV is a single column
        # (or a row-major grid) of samples -- confirm the file layout.
        signal_data = df.to_numpy(dtype="float32").ravel()
    except (ValueError, TypeError) as err:
        print(f"Warning: Non-numeric data in {csv_name} ({err}). Skipping.")
        return None

    # Validate data length
    if signal_data.shape[0] != SIGNAL_LENGTH:
        print(f"Warning: Signal length mismatch in {os.path.basename(csv_path)}. Expected {SIGNAL_LENGTH}, got {signal_data.shape[0]}. Skipping.")
        return None

    # Reshape the signal data into non-overlapping patches. Since the length
    # equals SIGNAL_LENGTH here, this only fails if SIGNAL_LENGTH is not a
    # multiple of SIGNAL_PATCH_SIZE (a configuration error).
    try:
        patches = signal_data.reshape(-1, SIGNAL_PATCH_SIZE)
        # Final Shape: (2560, 60)
    except ValueError:
        print(f"Warning: Could not reshape signal in {os.path.basename(csv_path)}. The total length must be divisible by the patch size. Skipping.")
        return None

    return torch.tensor(patches, dtype=torch.float32)

# --- 3. Main Execution Logic ---
# Image suffixes accepted as PRPD inputs; matched case-insensitively.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _atomic_torch_save(obj, path):
    """Save ``obj`` with ``torch.save`` via a temp file, then rename it to ``path``.

    ``os.replace`` renames atomically on POSIX, so an interrupted run does not
    leave a truncated file at ``path``. Such a file would later satisfy the
    "already processed" check in ``main`` and never be regenerated.
    """
    tmp_path = path + ".tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present if saving or renaming failed; don't leave debris behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main(args):
    """Preprocess every paired image/CSV sample found under ``args.raw_dir``.

    ``args.raw_dir`` is expected to contain one sub-directory per class. Each
    image ``<name>.png|.jpg|.jpeg`` (any letter case) in a class directory is
    paired with ``<name>.csv`` from the same directory. For each pair, the
    image and signal patch tensors are written to
    ``<args.save_dir>/<class>/<name>_img.pt`` and ``<name>_sig.pt``. Pairs
    whose two outputs already exist are skipped, so a re-run resumes an
    interrupted job.

    Args:
        args: Namespace with ``raw_dir`` and ``save_dir`` string attributes.
    """
    print("="*50)
    print("Starting Data Preprocessing")
    print(f"Raw Data Source: '{args.raw_dir}'")
    print(f"Processed Data Destination: '{args.save_dir}'")
    print("="*50)

    stats = {'processed': 0, 'skipped': 0, 'errors': 0}

    # Sorted so processing order (and log output) is reproducible; raw
    # os.listdir order depends on the filesystem.
    class_names = sorted(
        d for d in os.listdir(args.raw_dir)
        if os.path.isdir(os.path.join(args.raw_dir, d))
    )

    for class_name in tqdm(class_names, desc="Processing classes"):
        class_raw_path = os.path.join(args.raw_dir, class_name)
        class_save_path = os.path.join(args.save_dir, class_name)
        os.makedirs(class_save_path, exist_ok=True)

        # Case-insensitive match so files such as "sample.PNG" are not
        # silently ignored.
        image_files = sorted(
            f for f in os.listdir(class_raw_path)
            if f.lower().endswith(_IMAGE_EXTENSIONS)
        )

        for img_filename in tqdm(image_files, desc=f"  - {class_name}", leave=False):
            base_name = os.path.splitext(img_filename)[0]
            img_path = os.path.join(class_raw_path, img_filename)
            csv_path = os.path.join(class_raw_path, base_name + ".csv")

            # Define save paths for the processed tensors
            save_path_img = os.path.join(class_save_path, base_name + "_img.pt")
            save_path_sig = os.path.join(class_save_path, base_name + "_sig.pt")

            # Skip if the files have already been processed (safe because
            # outputs are only ever created complete, via _atomic_torch_save).
            if os.path.exists(save_path_img) and os.path.exists(save_path_sig):
                stats['skipped'] += 1
                continue

            # Skip if the corresponding CSV file is missing
            if not os.path.exists(csv_path):
                print(f"Warning: Corresponding CSV for {img_filename} not found. Skipping.")
                stats['errors'] += 1
                continue

            # Convert each modality into patches
            img_patches = image_to_patches(img_path)
            sig_patches = signal_to_patches(csv_path)

            # Save the patches only if both conversions succeed
            if img_patches is not None and sig_patches is not None:
                _atomic_torch_save(img_patches, save_path_img)
                _atomic_torch_save(sig_patches, save_path_sig)
                stats['processed'] += 1
            else:
                stats['errors'] += 1

    print("\n" + "="*50)
    print("Preprocessing Complete!")
    print(f"  - Successfully processed: {stats['processed']} samples")
    print(f"  - Skipped (already exist): {stats['skipped']} samples")
    print(f"  - Errors / Not Found: {stats['errors']} samples")
    print("="*50)

if __name__ == "__main__":
    # Command-line entry point: both directories default to the repository's
    # conventional data/ layout.
    cli = argparse.ArgumentParser(
        description="Preprocess PRPD images and PD signals into tensor patches."
    )
    cli.add_argument(
        '--raw_dir',
        type=str,
        default='data/raw',
        help='Directory containing the raw data, organized by class.',
    )
    cli.add_argument(
        '--save_dir',
        type=str,
        default='data/processed',
        help='Directory to save the processed tensor files.',
    )
    main(cli.parse_args())
